{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports & Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:04.619453Z",
     "start_time": "2018-12-08T23:57:04.488154Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from time import time\n",
    "import warnings\n",
    "from collections import Counter\n",
    "import logging\n",
    "from ast import literal_eval as make_tuple\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from gensim.models import Word2Vec, KeyedVectors\n",
    "from gensim.models.word2vec import LineSentence\n",
    "import word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:05.049257Z",
     "start_time": "2018-12-08T23:57:05.040701Z"
    }
   },
   "outputs": [],
   "source": [
    "# Reproducibility: fix NumPy's global RNG, which np.random.choice draws from later on\n",
    "np.random.seed(42)\n",
    "# Print wide DataFrames on one line instead of wrapping their columns\n",
    "pd.options.display.expand_frame_repr = False\n",
    "# Suppress every warning category for the rest of the session\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:05.244408Z",
     "start_time": "2018-12-08T23:57:05.240318Z"
    }
   },
   "outputs": [],
   "source": [
    "def format_time(t):\n",
    "    \"\"\"Format a duration given in seconds as a zero-padded 'HH:MM:SS' string.\n",
    "\n",
    "    The duration is rounded to whole seconds before it is split into hours,\n",
    "    minutes and seconds, so a value such as 59.6 carries over to '00:01:00'\n",
    "    instead of rendering as the invalid '00:00:60'.\n",
    "    \"\"\"\n",
    "    total_seconds = int(round(t))\n",
    "    minutes, seconds = divmod(total_seconds, 60)\n",
    "    hours, minutes = divmod(minutes, 60)\n",
    "    return f'{hours:02d}:{minutes:02d}:{seconds:02d}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logging Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:06.423935Z",
     "start_time": "2018-12-08T23:57:06.421773Z"
    }
   },
   "outputs": [],
   "source": [
    "# The file handler opens the log file right away, so its directory must already exist\n",
    "Path('logs').mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "logging.basicConfig(\n",
    "        filename='logs/word2vec.log',\n",
    "        level=logging.DEBUG,\n",
    "        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',\n",
    "        datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:34.969991Z",
     "start_time": "2018-12-08T23:57:34.967461Z"
    }
   },
   "outputs": [],
   "source": [
    "# Upper-case constant: this is the name eval_analogies() and the analogy preview cell read\n",
    "ANALOGIES_PATH = Path.cwd().parent / 'data' / 'analogies' / 'analogies-en.txt'\n",
    "analogies_path = ANALOGIES_PATH  # lower-case alias kept for backward compatibility"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up Sentence Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:57.298178Z",
     "start_time": "2018-12-08T23:57:57.289388Z"
    }
   },
   "outputs": [],
   "source": [
    "NGRAMS = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To facilitate memory-efficient text ingestion, the LineSentence class creates a generator from individual sentences contained in the provided text file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:57:58.496781Z",
     "start_time": "2018-12-08T23:57:58.494515Z"
    }
   },
   "outputs": [],
   "source": [
    "# Locate the n-gram corpus for the configured NGRAMS and wrap it for streaming\n",
    "sentence_path = Path('data') / 'ngrams' / f'ngrams_{NGRAMS}.txt'\n",
    "sentences = LineSentence(sentence_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train word2vec Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Word2Vec` class in [gensim.models.word2vec](https://radimrehurek.com/gensim/models/word2vec.html) implements the skip-gram and CBOW architectures introduced above. The notebook [word2vec](../03_word2vec.ipynb) contains additional implementation details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:09:31.218671Z",
     "start_time": "2018-12-08T23:58:43.716464Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Duration: 00:10:47\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters gathered in one mapping so they can be inspected or logged\n",
    "w2v_params = dict(\n",
    "    sg=1,             # 1 = skip-gram; any other value = CBOW\n",
    "    hs=0,             # 1 = hierarchical softmax; 0 = negative sampling\n",
    "    size=300,         # vector dimensionality\n",
    "    window=3,         # max distance between current and predicted word\n",
    "    min_count=50,     # ignore words with lower frequency\n",
    "    negative=10,      # number of noise words for negative sampling\n",
    "    workers=8,        # number of worker threads\n",
    "    iter=1,           # number of epochs (passes over the corpus)\n",
    "    alpha=0.025,      # initial learning rate\n",
    "    min_alpha=0.0001  # final learning rate\n",
    ")\n",
    "\n",
    "start = time()\n",
    "model = Word2Vec(sentences, **w2v_params)\n",
    "print('Duration:', format_time(time() - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Persist model & vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:10:01.380925Z",
     "start_time": "2018-12-09T00:10:01.143768Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create the target directory first so saving works on a fresh checkout\n",
    "model_dir = Path('models', 'baseline')\n",
    "model_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Full model (can be trained further) and the stand-alone word vectors\n",
    "model.save(str(model_dir / 'word2vec.model'))\n",
    "model.wv.save(str(model_dir / 'word_vectors.bin'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load model and vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T00:45:27.525905Z",
     "start_time": "2018-12-10T00:45:27.171700Z"
    }
   },
   "outputs": [],
   "source": [
    "model = Word2Vec.load('models/archive/word2vec.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T22:53:13.020767Z",
     "start_time": "2018-12-08T22:53:12.843245Z"
    }
   },
   "outputs": [],
   "source": [
    "wv = KeyedVectors.load('models/baseline/word_vectors.bin')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:11:04.596716Z",
     "start_time": "2018-12-09T00:11:04.539228Z"
    }
   },
   "outputs": [],
   "source": [
    "# One [token, index, count] record per vocabulary entry of the trained model\n",
    "vocab = [[token, entry.index, entry.count]\n",
    "         for token, entry in model.wv.vocab.items()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:11:04.905084Z",
     "start_time": "2018-12-09T00:11:04.868230Z"
    }
   },
   "outputs": [],
   "source": [
    "# Tabulate the records and order them from most to least frequent token\n",
    "vocab = pd.DataFrame(vocab, columns=['token', 'idx', 'count'])\n",
    "vocab = vocab.sort_values('count', ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:11:06.691657Z",
     "start_time": "2018-12-09T00:11:06.679881Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Int64Index: 50491 entries, 104 to 46372\n",
      "Data columns (total 3 columns):\n",
      "token    50491 non-null object\n",
      "idx      50491 non-null int64\n",
      "count    50491 non-null int64\n",
      "dtypes: int64(2), object(1)\n",
      "memory usage: 1.5+ MB\n"
     ]
    }
   ],
   "source": [
    "vocab.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:11:07.220241Z",
     "start_time": "2018-12-09T00:11:07.202935Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>idx</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>104</th>\n",
       "      <td>million</td>\n",
       "      <td>0</td>\n",
       "      <td>2340243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>business</td>\n",
       "      <td>1</td>\n",
       "      <td>1700662</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>december</td>\n",
       "      <td>2</td>\n",
       "      <td>1513533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>627</th>\n",
       "      <td>company</td>\n",
       "      <td>3</td>\n",
       "      <td>1490752</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>477</th>\n",
       "      <td>products</td>\n",
       "      <td>4</td>\n",
       "      <td>1368711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1071</th>\n",
       "      <td>net</td>\n",
       "      <td>5</td>\n",
       "      <td>1253343</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>market</td>\n",
       "      <td>6</td>\n",
       "      <td>1149048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>380</th>\n",
       "      <td>including</td>\n",
       "      <td>7</td>\n",
       "      <td>1110482</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>381</th>\n",
       "      <td>sales</td>\n",
       "      <td>8</td>\n",
       "      <td>1098312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>60</th>\n",
       "      <td>costs</td>\n",
       "      <td>9</td>\n",
       "      <td>1020383</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          token  idx    count\n",
       "104     million    0  2340243\n",
       "0      business    1  1700662\n",
       "66     december    2  1513533\n",
       "627     company    3  1490752\n",
       "477    products    4  1368711\n",
       "1071        net    5  1253343\n",
       "145      market    6  1149048\n",
       "380   including    7  1110482\n",
       "381       sales    8  1098312\n",
       "60        costs    9  1020383"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-09T00:11:14.683574Z",
     "start_time": "2018-12-09T00:11:14.648032Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count      50491\n",
       "mean        5110\n",
       "std        37525\n",
       "min           50\n",
       "10%           61\n",
       "20%           78\n",
       "30.0%        102\n",
       "40%          137\n",
       "50%          195\n",
       "60%          300\n",
       "70%          522\n",
       "80%         1164\n",
       "90%         4578\n",
       "max      2340243\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab['count'].describe(percentiles=np.arange(.1, 1, .1)).astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Analogies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T04:38:54.485888Z",
     "start_time": "2018-12-10T04:38:54.482447Z"
    }
   },
   "outputs": [],
   "source": [
    "def eval_analogies(w2v, max_vocab=15000):\n",
    "    \"\"\"Score a word2vec model on the analogy file at ANALOGIES_PATH.\n",
    "\n",
    "    Args:\n",
    "        w2v: model exposing `wv.accuracy` (deprecated gensim API, removed in 4.0).\n",
    "        max_vocab: forwarded as `restrict_vocab` to limit the words considered.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per analogy section and the columns\n",
    "        category, correct, incorrect and average (= correct / (correct + incorrect)).\n",
    "    \"\"\"\n",
    "    sections = w2v.wv.accuracy(ANALOGIES_PATH,\n",
    "                               restrict_vocab=max_vocab,\n",
    "                               case_insensitive=True)\n",
    "    counts = [[section['section'],\n",
    "               len(section['correct']),\n",
    "               len(section['incorrect'])] for section in sections]\n",
    "    return (pd.DataFrame(counts, columns=['category', 'correct', 'incorrect'])\n",
    "            .assign(average=lambda x: x.correct.div(x.correct.add(x.incorrect))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:21:32.500459Z",
     "start_time": "2018-12-08T23:21:32.498477Z"
    }
   },
   "outputs": [],
   "source": [
    "def total_accuracy(w2v):\n",
    "    \"\"\"Return [correct, incorrect, average] from the 'total' row of eval_analogies(w2v).\"\"\"\n",
    "    results = eval_analogies(w2v)\n",
    "    is_total = results.category == 'total'\n",
    "    total_row = results.loc[is_total, ['correct', 'incorrect', 'average']]\n",
    "    return total_row.squeeze().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T00:45:44.852024Z",
     "start_time": "2018-12-10T00:45:38.732034Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>category</th>\n",
       "      <th>correct</th>\n",
       "      <th>incorrect</th>\n",
       "      <th>average</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>capital-common-countries</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>capital-world</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>city-in-state</td>\n",
       "      <td>140</td>\n",
       "      <td>390</td>\n",
       "      <td>0.264151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>currency</td>\n",
       "      <td>2</td>\n",
       "      <td>26</td>\n",
       "      <td>0.071429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>family</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>gram1-adjective-to-adverb</td>\n",
       "      <td>48</td>\n",
       "      <td>134</td>\n",
       "      <td>0.263736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>gram2-opposite</td>\n",
       "      <td>23</td>\n",
       "      <td>67</td>\n",
       "      <td>0.255556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>gram3-comparative</td>\n",
       "      <td>240</td>\n",
       "      <td>222</td>\n",
       "      <td>0.519481</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>gram4-superlative</td>\n",
       "      <td>19</td>\n",
       "      <td>53</td>\n",
       "      <td>0.263889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>gram5-present-participle</td>\n",
       "      <td>90</td>\n",
       "      <td>182</td>\n",
       "      <td>0.330882</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>gram6-nationality-adjective</td>\n",
       "      <td>250</td>\n",
       "      <td>130</td>\n",
       "      <td>0.657895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>gram7-past-tense</td>\n",
       "      <td>94</td>\n",
       "      <td>286</td>\n",
       "      <td>0.247368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>gram8-plural</td>\n",
       "      <td>87</td>\n",
       "      <td>69</td>\n",
       "      <td>0.557692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>gram9-plural-verbs</td>\n",
       "      <td>72</td>\n",
       "      <td>138</td>\n",
       "      <td>0.342857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>total</td>\n",
       "      <td>1067</td>\n",
       "      <td>1701</td>\n",
       "      <td>0.385477</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       category  correct  incorrect   average\n",
       "0      capital-common-countries        2          4  0.333333\n",
       "1                 capital-world        0          0  0.000000\n",
       "2                 city-in-state      140        390  0.264151\n",
       "3                      currency        2         26  0.071429\n",
       "4                        family        0          0  0.000000\n",
       "5     gram1-adjective-to-adverb       48        134  0.263736\n",
       "6                gram2-opposite       23         67  0.255556\n",
       "7             gram3-comparative      240        222  0.519481\n",
       "8             gram4-superlative       19         53  0.263889\n",
       "9      gram5-present-participle       90        182  0.330882\n",
       "10  gram6-nationality-adjective      250        130  0.657895\n",
       "11             gram7-past-tense       94        286  0.247368\n",
       "12                 gram8-plural       87         69  0.557692\n",
       "13           gram9-plural-verbs       72        138  0.342857\n",
       "14                        total     1067       1701  0.385477"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy = eval_analogies(model)\n",
    "accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Validate Vector Arithmetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T01:00:35.772447Z",
     "start_time": "2018-12-10T01:00:35.756869Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>:</td>\n",
       "      <td>capital-common-countries</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>athens</td>\n",
       "      <td>greece</td>\n",
       "      <td>baghdad</td>\n",
       "      <td>iraq</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>athens</td>\n",
       "      <td>greece</td>\n",
       "      <td>bangkok</td>\n",
       "      <td>thailand</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>athens</td>\n",
       "      <td>greece</td>\n",
       "      <td>beijing</td>\n",
       "      <td>china</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>athens</td>\n",
       "      <td>greece</td>\n",
       "      <td>berlin</td>\n",
       "      <td>germany</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        0                         1        2         3\n",
       "0       :  capital-common-countries      NaN       NaN\n",
       "1  athens                    greece  baghdad      iraq\n",
       "2  athens                    greece  bangkok  thailand\n",
       "3  athens                    greece  beijing     china\n",
       "4  athens                    greece   berlin   germany"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(ANALOGIES_PATH, header=None, sep=' ').head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T08:11:19.340922Z",
     "start_time": "2018-12-10T08:11:19.334225Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                  term  similarity\n",
      "0              android    0.600454\n",
      "1           smartphone    0.581685\n",
      "2                  app    0.559129\n",
      "3          smartphones    0.533848\n",
      "4  smartphones_tablets    0.526129\n",
      "5             handsets    0.514813\n",
      "6         smart_phones    0.512868\n",
      "7                apple    0.507795\n",
      "8                 apps    0.505517\n",
      "9              handset    0.491526\n"
     ]
    }
   ],
   "source": [
    "sims=model.wv.most_similar(positive=['iphone'], \n",
    "                           restrict_vocab=15000)\n",
    "print(pd.DataFrame(sims, columns=['term', 'similarity']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-10T08:14:19.395370Z",
     "start_time": "2018-12-10T08:14:19.381754Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             term  similarity\n",
      "0  united_kingdom    0.606630\n",
      "1         germany    0.585644\n",
      "2     netherlands    0.578868\n",
      "3           italy    0.547168\n",
      "4           india    0.545213\n",
      "5           spain    0.539029\n",
      "6       singapore    0.535106\n",
      "7       australia    0.525464\n",
      "8         belgium    0.523677\n",
      "9          sweden    0.510462\n"
     ]
    }
   ],
   "source": [
    "analogy = model.wv.most_similar(positive=['france', 'london'], \n",
    "                                negative=['paris'], \n",
    "                                restrict_vocab=15000)\n",
    "print(pd.DataFrame(analogy, columns=['term', 'similarity']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check similarity for random words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T23:10:41.702789Z",
     "start_time": "2018-12-08T23:10:41.640280Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
      "  if np.issubdtype(vec.dtype, np.int):\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>staff</th>\n",
       "      <th>enables</th>\n",
       "      <th>times</th>\n",
       "      <th>fees</th>\n",
       "      <th>sources</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>personnel</td>\n",
       "      <td>allows</td>\n",
       "      <td>twice</td>\n",
       "      <td>fee</td>\n",
       "      <td>source</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>team</td>\n",
       "      <td>enabling</td>\n",
       "      <td>standpoint_advantageous</td>\n",
       "      <td>professional_fees</td>\n",
       "      <td>primary_source</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>teams</td>\n",
       "      <td>helps</td>\n",
       "      <td>vimovo_orange_book</td>\n",
       "      <td>checkcard</td>\n",
       "      <td>sourced</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>professionals</td>\n",
       "      <td>enable</td>\n",
       "      <td>millisecond</td>\n",
       "      <td>commissions</td>\n",
       "      <td>readily_available</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>staffed</td>\n",
       "      <td>allowing</td>\n",
       "      <td>saturdays</td>\n",
       "      <td>atm_debit_card</td>\n",
       "      <td>internally_generated</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>hiring</td>\n",
       "      <td>enabled</td>\n",
       "      <td>assets_liabilities_react_differently</td>\n",
       "      <td>gds_reservation_booking</td>\n",
       "      <td>generated</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>consultants</td>\n",
       "      <td>allow</td>\n",
       "      <td>twice_weekly</td>\n",
       "      <td>interchange_fees_swipe</td>\n",
       "      <td>biological_contaminants_pollen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>hired</td>\n",
       "      <td>leverages</td>\n",
       "      <td>day</td>\n",
       "      <td>noticing</td>\n",
       "      <td>repair_reconstruct_damaged</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>engineers</td>\n",
       "      <td>lets</td>\n",
       "      <td>weekdays</td>\n",
       "      <td>nonsufficient</td>\n",
       "      <td>alternative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>salespeople</td>\n",
       "      <td>easy</td>\n",
       "      <td>uvb</td>\n",
       "      <td>bno_usci_cper_usag</td>\n",
       "      <td>znse</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           staff    enables                                 times                     fees                         sources\n",
       "0      personnel     allows                                 twice                      fee                          source\n",
       "1           team   enabling               standpoint_advantageous        professional_fees                  primary_source\n",
       "2          teams      helps                    vimovo_orange_book                checkcard                         sourced\n",
       "3  professionals     enable                           millisecond              commissions               readily_available\n",
       "4        staffed   allowing                             saturdays           atm_debit_card            internally_generated\n",
       "5         hiring    enabled  assets_liabilities_react_differently  gds_reservation_booking                       generated\n",
       "6    consultants      allow                          twice_weekly   interchange_fees_swipe  biological_contaminants_pollen\n",
       "7          hired  leverages                                   day                 noticing      repair_reconstruct_damaged\n",
       "8      engineers       lets                              weekdays            nonsufficient                     alternative\n",
       "9    salespeople       easy                                   uvb       bno_usci_cper_usag                            znse"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "VALID_SET = 5  # Random set of words to get nearest neighbors for\n",
    "VALID_WINDOW = 100  # Most frequent words to draw validation set from\n",
    "valid_examples = np.random.choice(VALID_WINDOW, size=VALID_SET, replace=False)\n",
    "similars = pd.DataFrame()\n",
    "\n",
    "# `vocab` is sorted by descending count but keeps its original row labels, so the\n",
    "# sampled frequency ranks are resolved by position (.iloc) rather than by label (.loc).\n",
    "for rank in sorted(valid_examples):\n",
    "    word = vocab['token'].iloc[rank]\n",
    "    similars[word] = [term for term, _ in model.wv.most_similar(word)]\n",
    "similars"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continue Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-category accuracy of the current model, stored as the 'baseline' column\n",
    "baseline_scores = eval_analogies(model).set_index('category')\n",
    "accuracies = baseline_scores['average'].to_frame('baseline')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-08T21:26:29.866811Z",
     "start_time": "2018-12-08T20:10:12.950824Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/ipykernel_launcher.py:5: DeprecationWarning: Call to deprecated `accuracy` (Method will be removed in 4.0.0, use self.evaluate_word_analogies() instead).\n",
      "  \"\"\"\n",
      "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
      "  if np.issubdtype(vec.dtype, np.int):\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 | Duration: 464.0 | Accuracy: 28.93% \n",
      "2 | Duration: 457.8 | Accuracy: 28.83% \n",
      "3 | Duration: 459.2 | Accuracy: 28.97% \n",
      "4 | Duration: 456.9 | Accuracy: 28.60% \n",
      "5 | Duration: 457.4 | Accuracy: 29.69% \n",
      "6 | Duration: 456.8 | Accuracy: 29.40% \n",
      "7 | Duration: 457.7 | Accuracy: 29.91% \n",
      "8 | Duration: 456.4 | Accuracy: 29.61% \n",
      "9 | Duration: 456.1 | Accuracy: 29.37% \n",
      "10 | Duration: 454.6 | Accuracy: 29.17% \n"
     ]
    }
   ],
   "source": [
    "# Make sure the checkpoint directory exists before the first (slow) epoch runs\n",
    "Path('word2vec', 'models').mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for i in range(1, 11):\n",
    "    start = time()\n",
    "    # One additional pass over the corpus per iteration\n",
    "    model.train(sentences, epochs=1, total_examples=model.corpus_count)\n",
    "    accuracy = eval_analogies(model).set_index('category').average\n",
    "    # Label the new accuracy column with the epoch counter\n",
    "    accuracies = accuracies.join(accuracy.to_frame(f'{i}'))\n",
    "    print(f'{i} | Duration: {format_time(time() - start)} | Accuracy: {accuracy.total:.2%}')\n",
    "    model.save(f'word2vec/models/word2vec_{i}.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.wv.save('word_vectors_final.bin')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
